import os

import tensorflow as tf
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.utils import to_categorical

from AlexNet import AlexNet

# Input resolution the images are upsampled to before entering AlexNet.
IMAGE_SIZE = 227
# Fashion-MNIST has 10 clothing classes (labels 0-9).
NUM_CLASSES = 10
BATCH_SIZE = 32
EPOCHS = 5
MODEL_PATH = './AlexNet/alexnet_fashion_mnist.h5'


def _preprocess(images, labels):
    """Scale a batch of uint8 pixels to [0, 1] and upsample to IMAGE_SIZE x IMAGE_SIZE."""
    images = tf.cast(images, tf.float32) / 255.0
    images = tf.image.resize(images, [IMAGE_SIZE, IMAGE_SIZE])
    return images, labels


def _make_dataset(images, labels, *, shuffle=False):
    """Build a batched tf.data pipeline of (image, one-hot label) pairs.

    Normalization and resizing happen lazily, per batch, inside the pipeline.
    Resizing the whole array up front would need roughly 12 GB for the
    60,000 training images at 227x227x1 float32, which typically runs out
    of memory.

    Args:
        images: uint8 array of 28x28 grayscale images, reshaped here to
            (N, 28, 28, 1).
        labels: integer class labels in [0, NUM_CLASSES).
        shuffle: reshuffle the examples every epoch. Use this for training,
            because ``Model.fit`` ignores its own ``shuffle`` argument for
            tf.data inputs.

    Returns:
        A prefetched ``tf.data.Dataset`` yielding batches of BATCH_SIZE.
    """
    images = images.reshape(-1, 28, 28, 1)
    labels = to_categorical(labels, num_classes=NUM_CLASSES)
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    if shuffle:
        # Full-size buffer gives a uniform shuffle; the raw uint8 data is small.
        ds = ds.shuffle(buffer_size=len(images))
    ds = ds.batch(BATCH_SIZE)
    ds = ds.map(_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.prefetch(tf.data.AUTOTUNE)


def main():
    """Train AlexNet on Fashion-MNIST, save it to MODEL_PATH, and print test accuracy."""
    # Load the Fashion-MNIST dataset.
    (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

    train_ds = _make_dataset(train_images, train_labels, shuffle=True)
    test_ds = _make_dataset(test_images, test_labels)

    # Create and compile the AlexNet model.
    model = AlexNet()
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

    # Train the model.
    # NOTE(review): the test set doubles as validation data. This is harmless
    # while nothing (e.g. early stopping) selects on validation metrics, but a
    # held-out split of the training data would be cleaner.
    model.fit(train_ds, epochs=EPOCHS, validation_data=test_ds)

    # Save the model. Create the target directory first, since it may not exist.
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    model.save(MODEL_PATH)

    # Evaluate the model.
    test_loss, test_acc = model.evaluate(test_ds)
    print(f"Test accuracy: {test_acc}")


if __name__ == "__main__":
    main()
